
[eplatero] Add support for exporting and compiling models for SpD #119

Merged: 30 commits into quic:main on Dec 11, 2024

Conversation

@quic-agokhale (Contributor) commented Sep 19, 2024

This change has been validated and posted on behalf of Erick Platero.

It adds support for generating a Target LM (TLM) that runs as a verifier model by outputting the logits for every position of the input sequence instead of only the last one.

It also allows compiling the Target and Draft LMs with specializations that support SpD.

Usage:

TLM:
tlm = QEFFAutoModelForCausalLM.from_pretrained()
tlm.transform(num_speculative_tokens=)
tlm.export_and_compile()

DLM:
dlm = QEFFAutoModelForCausalLM.from_pretrained()
dlm.transform(is_dlm=True)
dlm.export_and_compile()

@vbaddi (Contributor) commented Sep 20, 2024

Linter checks and DCO are failing. Could you please do the following and re-push to format the code:

- pip install pre-commit
- pre-commit install
- git commit -m ...

@eplatero97 (Contributor)

Unit Tests

Just added unit tests with the results below:

(qeff_env) eplatero@aus121-r760-0:/prj/crd/austin/validation/scratch/users/eplatero/qefficient_spd/efficient-transformers$ pytest tests/spd/test_tlm_dlm_export_and_compile.py
================================================================================================================================= test session starts ==================================================================================================================================
platform linux -- Python 3.8.20, pytest-8.3.3, pluggy-1.5.0 -- /prj/crd/austin/validation/scratch/users/eplatero/qefficient_spd/efficient-transformers/qeff_env/bin/python3.8
cachedir: .pytest_cache
rootdir: /prj/crd/austin/validation/scratch/users/eplatero/qefficient_spd/efficient-transformers
configfile: pyproject.toml
collected 2 items

tests/spd/test_tlm_dlm_export_and_compile.py::test_llama_tlm_logit_dims[llama] WARNING - QEfficient - Updating attn_implementation to be 'eager', got None
Fetching 7 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 96579.37it/s]
WARNING - QEfficient - Overriding /local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/onnx
============== Diagnostic Run torch.onnx.export version 2.0.0+cpu ==============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================


=============== PyTorch vs. fp32 ONNXRT (MAD) ===============

logits           1.33514404296875e-05
past_keys (mean)                 6.141860715367577e-07
past_value (mean)                4.351139068603516e-06

=====================================================================

Running AI 100 compiler: /opt/qti-aic/exec/qaic-exec -m=/local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/onnx/TinyLlama_TinyLlama-1.1B-Chat-v1.0_kv.onnx -aic-hw -aic-hw-version=2.0 -network-specialization-config=/local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/qpc_16cores_1bs_32pl_128cl_-1mos_8fbs_1devices_mxfp6_mxint8/qpcs/specializations.json -convert-to-fp16 -retained-state -aic-num-cores=16 -custom-IO-list-file=/local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/onnx/custom_io_int8.yaml -compile-only -aic-binary-dir=/local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/qpc_16cores_1bs_32pl_128cl_-1mos_8fbs_1devices_mxfp6_mxint8/qpcs/qpcs -mxfp6-matmul
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

===================== Compilation Done! =====================

PASSED
tests/spd/test_tlm_dlm_export_and_compile.py::test_llama_dlm_logit_dims[llama] WARNING - QEfficient - Updating attn_implementation to be 'eager', got None
Fetching 7 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 52996.62it/s]
WARNING - QEfficient - Overriding /local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/onnx
============== Diagnostic Run torch.onnx.export version 2.0.0+cpu ==============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================


=============== PyTorch vs. fp32 ONNXRT (MAD) ===============

logits           1.33514404296875e-05
past_keys (mean)                 6.141860715367577e-07
past_value (mean)                4.351139068603516e-06

=====================================================================

Running AI 100 compiler: /opt/qti-aic/exec/qaic-exec -m=/local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/onnx/TinyLlama_TinyLlama-1.1B-Chat-v1.0_kv.onnx -aic-hw -aic-hw-version=2.0 -network-specialization-config=/local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/qpc_16cores_1bs_32pl_128cl_-1mos_8fbs_1devices_mxfp6_mxint8/qpcs/specializations.json -convert-to-fp16 -retained-state -aic-num-cores=16 -custom-IO-list-file=/local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/onnx/custom_io_int8.yaml -compile-only -aic-binary-dir=/local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/qpc_16cores_1bs_32pl_128cl_-1mos_8fbs_1devices_mxfp6_mxint8/qpcs/qpcs -mxfp6-matmul
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

===================== Compilation Done! =====================

PASSED

============================================================================================================================ 2 passed in 551.93s (0:09:11) =============================================================================================================================

API

To integrate the SpD changes, I changed the API slightly from:

# tlm
tlm = QEFFAutoModelForCausalLM.from_pretrained()
tlm.transform(num_speculative_tokens=)
tlm.export_and_compile()

# dlm
dlm = QEFFAutoModelForCausalLM.from_pretrained()
dlm.transform(is_dlm=True)
dlm.export_and_compile()

to

# tlm
tlm = QEFFAutoModelForCausalLM.from_pretrained(model_name, num_speculative_tokens=)
tlm.export_and_compile()
# dlm
dlm = QEFFAutoModelForCausalLM.from_pretrained(model_name, is_dlm=True)
dlm.export_and_compile()

I made this change because from_pretrained() automatically calls the transform function, which then sets the is_transformed member variable to True. Thus, everything happens in one step.
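For illustration, here is a filled-in version of the snippet above; the model name, the value of num_speculative_tokens, and the import path are example assumptions, not part of this PR:

```Python
from QEfficient import QEFFAutoModelForCausalLM

# TLM: from_pretrained forwards num_speculative_tokens to transform(), so the model
# is transformed in the same step (4 speculative tokens is an example value).
tlm = QEFFAutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_speculative_tokens=4)
tlm.export_and_compile()

# DLM: same idea, with is_dlm=True forwarded to transform().
dlm = QEFFAutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", is_dlm=True)
dlm.export_and_compile()
```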

Next Steps

Once the llama changes have been approved, the plan is to make corresponding changes to the rest of the supported models along with their unit tests.

Also, we are still discussing where best to document these SpD changes: maybe updating the transform doc with the two parameters num_speculative_tokens and is_dlm, or possibly adding a new document explaining this. We would appreciate your thoughts on this.

@vbaddi (Contributor) commented Sep 30, 2024

@ochougul @irajagop @quic-rishinr Could you all please review this PR?

{"batch_size": str(batch_size), "seq_len": "1", "ctx_len": str(ctx_len)},
]
}
# Create specialization cfgs

Can we simplify specialization creation? The only change when num_speculative_tokens is provided is decode_seq_len. Could it be done something like:

decode_seq_len = 1 if num_speculative_tokens is None else num_speculative_tokens + 1
specialization = [
    {"batch_size": str(batch_size), "seq_len": str(prompt_len), "ctx_len": str(ctx_len)},
    {"batch_size": str(batch_size), "seq_len": str(decode_seq_len), "ctx_len": str(ctx_len)},
]
if is_dlm:
    specialization.append(
        {"batch_size": str(batch_size), "seq_len": "2", "ctx_len": str(ctx_len)}
    )
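For concreteness, with the values used in the compile logs above (batch_size=1, prompt_len=32, ctx_len=128) and an example num_speculative_tokens=3, the suggested code would build:

```Python
# Illustrative result of the suggested snippet (num_speculative_tokens=3 is an example value)
specialization = [
    {"batch_size": "1", "seq_len": "32", "ctx_len": "128"},  # prefill
    {"batch_size": "1", "seq_len": "4", "ctx_len": "128"},   # decode: num_speculative_tokens + 1
]
# and, if is_dlm were set instead, an extra entry:
# {"batch_size": "1", "seq_len": "2", "ctx_len": "128"}
```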


Great suggestion, I just condensed it. Let me know what you think.

QEfficient/transformers/modeling_spd_utils.py (outdated, resolved)
@@ -188,6 +192,19 @@ def transform(self, **kwargs):
if isinstance(self.model.config.quantization_config, QEffGPTQConfig):
self._pytorch_transforms.insert(0, GPTQToMatmulNbitsTransform)

num_speculative_tokens = kwargs.get("num_speculative_tokens", None)
is_dlm = kwargs.get("is_dlm", False)
assert (

Can we raise a ValueError instead of assert?


Yep, added a ValueError instead of the AssertionError.
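For reference, a minimal sketch of what that check might look like with a ValueError; the assert's condition is truncated in the excerpt above, so the mutual-exclusivity assumption and the helper name are mine:

```Python
def _validate_spd_kwargs(**kwargs):
    # Sketch only: assumes the truncated assert guards against configuring a model
    # as both a TLM and a DLM at the same time.
    num_speculative_tokens = kwargs.get("num_speculative_tokens", None)
    is_dlm = kwargs.get("is_dlm", False)
    if num_speculative_tokens is not None and is_dlm:
        raise ValueError(
            "A model cannot be both a TLM (num_speculative_tokens) and a DLM (is_dlm)."
        )
```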

input_ids = pt_outputs["logits"][:, -1].argmax(-1).reshape(-1, 1)
else:
input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1)
pt_outputs["input_ids"] = input_ids

Could you please confirm whether it should be pt_outputs["input_ids"] = input_ids or updated_inputs["input_ids"] = input_ids


This is a good catch; it was indeed updated_inputs. Tomorrow I will add unit tests that check SpD functionality for both CB and non-CB models to make sure they catch this mistake and any other possible ones.
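For reference, a hedged reconstruction of the corrected flow with dummy tensors; the real condition and surrounding code are not visible in the excerpt, so the names used here are assumptions:

```Python
import torch

pt_outputs = {"logits": torch.randn(1, 4, 32000)}  # dummy [bsz, positions, vocab] logits
updated_inputs = {}
take_last_position_only = False  # assumed flag standing in for the excerpt's hidden condition

if take_last_position_only:
    input_ids = pt_outputs["logits"][:, -1].argmax(-1).reshape(-1, 1)
else:
    input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1)
updated_inputs["input_ids"] = input_ids  # the fix: write to updated_inputs, not pt_outputs
```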

@eplatero97 (Contributor) commented Oct 1, 2024

Feedback

@quic-rishinr, thank you for the feedback. I have updated the changes; please let me know what you think.

I explicitly added num_speculative_tokens and is_dlm to the transform method to provide some documentation on how to create an SpD model.

Validation

Results showing the unit tests passing for CB SpD are below:

=============================================================================================================================================================================== test session starts ===============================================================================================================================================================================
platform linux -- Python 3.8.20, pytest-8.3.3, pluggy-1.5.0 -- /prj/crd/austin/validation/scratch/users/eplatero/qefficient_spd/efficient-transformers/qeff_env/bin/python3.8
cachedir: .pytest_cache
rootdir: /prj/crd/austin/validation/scratch/users/eplatero/qefficient_spd/efficient-transformers
configfile: pyproject.toml
collected 2 items                                                                                                                                                                                                                                                                                                                                                                 

tests/spd/test_tlm_dlm_export_and_compile.py::test_llama_tlm_logit_dims[llama] WARNING - QEfficient - Updating attn_implementation to be 'eager', got None
Fetching 7 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 126009.13it/s]
WARNING - QEfficient - Overriding /local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/onnx
============== Diagnostic Run torch.onnx.export version 2.0.0+cpu ==============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================


=============== PyTorch vs. fp32 ONNXRT (MAD) ===============

logits           1.33514404296875e-05
past_keys (mean)                 6.141860715367577e-07
past_value (mean)                4.351139068603516e-06

=====================================================================

Running AI 100 compiler: /opt/qti-aic/exec/qaic-exec -m=/local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/onnx/TinyLlama_TinyLlama-1.1B-Chat-v1.0_kv.onnx -aic-hw -aic-hw-version=2.0 -network-specialization-config=/local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/qpc_16cores_1bs_32pl_128cl_-1mos_8fbs_1devices_mxfp6_mxint8/qpcs/specializations.json -convert-to-fp16 -retained-state -aic-num-cores=16 -custom-IO-list-file=/local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/onnx/custom_io_int8.yaml -compile-only -aic-binary-dir=/local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/qpc_16cores_1bs_32pl_128cl_-1mos_8fbs_1devices_mxfp6_mxint8/qpcs/qpcs -mxfp6-matmul
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

===================== Compilation Done! =====================

PASSED
tests/spd/test_tlm_dlm_export_and_compile.py::test_llama_dlm_logit_dims[llama] WARNING - QEfficient - Updating attn_implementation to be 'eager', got None
Fetching 7 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 51964.83it/s]
WARNING - QEfficient - Overriding /local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/onnx
============== Diagnostic Run torch.onnx.export version 2.0.0+cpu ==============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================


=============== PyTorch vs. fp32 ONNXRT (MAD) ===============

logits           1.33514404296875e-05
past_keys (mean)                 6.141860715367577e-07
past_value (mean)                4.351139068603516e-06

=====================================================================

Running AI 100 compiler: /opt/qti-aic/exec/qaic-exec -m=/local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/onnx/TinyLlama_TinyLlama-1.1B-Chat-v1.0_kv.onnx -aic-hw -aic-hw-version=2.0 -network-specialization-config=/local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/qpc_16cores_1bs_32pl_128cl_-1mos_8fbs_1devices_mxfp6_mxint8/qpcs/specializations.json -convert-to-fp16 -retained-state -aic-num-cores=16 -custom-IO-list-file=/local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/onnx/custom_io_int8.yaml -compile-only -aic-binary-dir=/local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/qpc_16cores_1bs_32pl_128cl_-1mos_8fbs_1devices_mxfp6_mxint8/qpcs/qpcs -mxfp6-matmul
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

===================== Compilation Done! =====================

PASSED

========================================================================================================================================================================== 2 passed in 575.63s (0:09:35) ==========================================================================================================================================================================

Next Steps

Tomorrow, I will add unit tests to validate this functionality on non-CB models as well, to make sure it works there too.

@eplatero97 (Contributor)

Validation

Added a unit test that also tests a non-CB model with SpD. All four tests pass, as shown below:

$ pytest -rA tests/spd/test_tlm_dlm_export_and_compile.py
=================================== PASSES ===================================
____________________ test_llama_tlm_logit_dims[llama0] ____________________
------------------------------ Captured log call ------------------------------
WARNING  QEfficient:modeling_auto.py:111 Updating attn_implementation to be 'eager', got None
WARNING  QEfficient:export_hf_to_cloud_ai_100.py:354 Overriding /local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/onnx
____________________ test_llama_tlm_logit_dims[llama1] ____________________
------------------------------ Captured log call ------------------------------
WARNING  QEfficient:modeling_auto.py:111 Updating attn_implementation to be 'eager', got None
WARNING  QEfficient:export_hf_to_cloud_ai_100.py:354 Overriding /local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/onnx
____________________ test_llama_dlm_logit_dims[llama0] ____________________
------------------------------ Captured log call ------------------------------
WARNING  QEfficient:modeling_auto.py:111 Updating attn_implementation to be 'eager', got None
WARNING  QEfficient:export_hf_to_cloud_ai_100.py:354 Overriding /local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/onnx
____________________ test_llama_dlm_logit_dims[llama1] ____________________
------------------------------ Captured log call ------------------------------
WARNING  QEfficient:modeling_auto.py:111 Updating attn_implementation to be 'eager', got None
WARNING  QEfficient:export_hf_to_cloud_ai_100.py:354 Overriding /local/mnt/qt_drive/users/eplatero/qeff_cache/TinyLlama/TinyLlama-1.1B-Chat-v1.0/onnx
============================ short test summary info =============================
PASSED tests/spd/test_tlm_dlm_export_and_compile.py::test_llama_tlm_logit_dims[llama0]
PASSED tests/spd/test_tlm_dlm_export_and_compile.py::test_llama_tlm_logit_dims[llama1]
PASSED tests/spd/test_tlm_dlm_export_and_compile.py::test_llama_dlm_logit_dims[llama0]
PASSED tests/spd/test_tlm_dlm_export_and_compile.py::test_llama_dlm_logit_dims[llama1]
======================================================================================== 4 passed in 1123.57s (0:18:43) =========================================================================================

The non-CB DLM test essentially exercises the vanilla non-CB workflow, since the only thing the DLM change does is add an extra specialization.

Thus, together these unit tests cover the SpD changes while also verifying backward compatibility.

Let me know if this is sufficient testing @quic-rishinr, @ochougul , @irajagop, @vbaddi.

Once approved, I can move on to implement SpD changes on the rest of supported models.

@quic-rishinr (Contributor)

Hi @eplatero97, SpD support hasn't been added to the CLI APIs like infer. Could you please add support for SpD in the CLI API as well?

@eplatero97 (Contributor)

@irajagop, @quic-rishinr, @ochougul
I have posted a first draft of the changes we discussed during yesterday's meeting.

It is not complete (creating specializations with the new input buffer num_logits_to_keep is missing), but most of the changes that will take place are here.

Please take a look when you have time and let me know what you think.

@eplatero97 (Contributor)

non_constant_tensor_not_supported.txt

The above patch passes ONNX export but fails compilation with the error below:

# see attached file for full log
Error message:  [Operator-'/Range'] : Range: Non-constant start tensor not supported. 

The above is directly caused by line 43 in modeling_spd_utils.py:

indices = torch.arange(lower_idx[0], upper_idx[0])

I'm going to assume this is a compiler limitation and instead work on a patch where num_logits_to_keep is fixed between prefill and decode. This will require exporting the ONNX with the user-defined num_logits_to_keep.

@eplatero97 (Contributor)

topk_non_constant_k_tensor_not_supported_when_passing_num_logits_to_keep_as_input.log

The latest patch implements a fixed num_logits_to_keep during the ONNX export process. I see a similar error as before, caused by the topk function on line 61 of modeling_spd_utils.py:

Error message:  [Operator-'/TopK'] : TopK: Non-constant k tensor not supported. 

The next thing to try will be to pass num_logits_to_keep implicitly as an instance variable on the model instead of passing it explicitly in the forward pass.

@eplatero97 (Contributor)

topk_export_and_compile_pass.log

With the latest patch, the attached log shows that export and compilation now pass, since I pass num_logits_to_keep implicitly instead of explicitly during the forward pass.

@eplatero97 (Contributor) commented Oct 18, 2024

@quic-rishinr, @vbaddi, @irajagop, @ochougul, @quic-agokhale

Please review the latest patch as it contains the proposed flow.

Validation runs that export both TLM and DLM for both CB and non-CB are shown below:

$ pytest -rA tests/spd/test_tlm_dlm_export_and_compile.py
================================================================================================================================================================= short test summary info ==================================================================================================================================================================
PASSED tests/spd/test_tlm_dlm_export_and_compile.py::test_llama_tlm_logit_dims[CB llama]
PASSED tests/spd/test_tlm_dlm_export_and_compile.py::test_llama_tlm_logit_dims[non-CB llama]
PASSED tests/spd/test_tlm_dlm_export_and_compile.py::test_llama_dlm_logit_dims[CB llama]
PASSED tests/spd/test_tlm_dlm_export_and_compile.py::test_llama_dlm_logit_dims[non-CB llama]
============================================================================================================================================================== 4 passed in 217.88s (0:03:37) ===============================================================================================================================================================

@eplatero97 force-pushed the gen_spd_models branch 3 times, most recently from 32ce801 to 03c8b02 on October 24, 2024 11:08
@eplatero97 (Contributor) commented Nov 6, 2024

Replaced the topk solution in modeling_spd_utils.py with the gather approach provided by Ankit.

Unit tests are all passing:

================================================================================================================================================================================================================================= short test summary info ==================================================================================================================================================================================================================================
PASSED tests/spd/test_tlm_dlm_export_and_compile.py::test_llama_tlm_logit_dims[CB llama]
PASSED tests/spd/test_tlm_dlm_export_and_compile.py::test_llama_tlm_logit_dims[non-CB llama]
PASSED tests/spd/test_tlm_dlm_export_and_compile.py::test_llama_dlm_logit_dims[CB llama]
PASSED tests/spd/test_tlm_dlm_export_and_compile.py::test_llama_dlm_logit_dims[non-CB llama]
============================================================================================================================================================================================================================== 4 passed in 254.84s (0:04:14) ===============================================================================================================================================================================================================================

I made sure the gather operation does what is desired by running the code below, which returns the desired num_logits_to_keep positions:

import torch

# Example input tensor (resembles position_ids with -1 padding): the argmax of each
# row gives the index of the last valid position in that sequence
tensor = torch.tensor([
    [0,1,2,3,4,5,6,7,8,9,10,11,12],
    [0,1,2,3,4,5,6,7,8,9,-1,-1,-1],
    [0,1,2,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1]
])
outputs = torch.arange(1000).view(1,1,1000).repeat(3,32,1) # shape: [3, 32, 1000]

# Get the argmax along dimension 1 with keepdim=True
logit_idx = tensor.argmax(1, keepdim=True) # shape: [bsz,1]
num_logits_to_keep = 4

# Starting index of the num_logits_to_keep consecutive positions to gather for each sequence
lower_idx = torch.where(logit_idx <= num_logits_to_keep, 0, logit_idx - num_logits_to_keep).view(-1,1) # shape: [bsz, 1]
spec_idx = torch.arange(num_logits_to_keep).view(1,-1) # shape: [1, num_logits_to_keep]
indices = torch.add(lower_idx, spec_idx).unsqueeze(2) # shape: [bsz, num_logits_to_keep, 1]
indices = indices.expand(-1, -1, outputs.size(-1)) # shape: [bsz, num_logits_to_keep, d_model]
hidden_states = torch.gather(outputs, dim=1, index=indices) # shape: [bsz, num_logits_to_keep, d_model]

The hidden_states object looks something like this:

tensor([[[  0,   1,   2,  ..., 997, 998, 999],
         [  0,   1,   2,  ..., 997, 998, 999],
         [  0,   1,   2,  ..., 997, 998, 999],
         [  0,   1,   2,  ..., 997, 998, 999]],

        [[  0,   1,   2,  ..., 997, 998, 999],
         [  0,   1,   2,  ..., 997, 998, 999],
         [  0,   1,   2,  ..., 997, 998, 999],
         [  0,   1,   2,  ..., 997, 998, 999]],

        [[  0,   1,   2,  ..., 997, 998, 999],
         [  0,   1,   2,  ..., 997, 998, 999],
         [  0,   1,   2,  ..., 997, 998, 999],
         [  0,   1,   2,  ..., 997, 998, 999]]])

ochougul and others added 16 commits December 5, 2024 05:34
…ed to regenerate ONNX for different values of num_logits_to_keep only qpc is recompiled, * ran formatter , * reorganized pytorch transforms

Signed-off-by: Onkar Chougule <[email protected]>
Signed-off-by: eplatero <[email protected]>
Signed-off-by: eplatero <[email protected]>

Do we still need this test? Can the shape assert statements be added to the test you added in tests/transformers/models/test_causal_lm_models.py::test_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name)?


Good point, we don't need that test anymore.

Comment on lines 127 to 130
if is_tlm:
hf_model, transformed = SpDTransform.apply(hf_model)
assert transformed


Move the test-specific code to the test function and use this as a common function that can be used by multiple tests without any test-specific code. Is it possible to move this to the SpD test?

Comment on lines 139 to 146
inputs = dict(
input_ids=input_ids,
position_ids=torch.Tensor([range(input_ids.shape[1])]).long(),
past_key_values=tuple(past_key_values),
output_hidden_states=True,
)
if is_tlm:
inputs["num_logits_to_keep"] = torch.zeros((input_len, 1))

You can take inputs as an argument to this function and move the test-specific input-creation code to the respective test.

Comment on lines 1069 to 1073

def validate_tlm_gen_tokens(self):
gen_len = (self.generated_ids)
self.prefill_seq_len


Do you want to remove this?


yes, just removed it 👍

# Find whether SpD acceptance rate is fully matching (100% acceptance rate).
is_matching = None
if self.is_tlm:
is_matching = self._qaic_model.is_spd_acceptance_rate_fully_matching()

We should check this only in testing code, not in normal execution code.

@@ -51,7 +51,7 @@ def __repr__(self) -> str:

@classmethod
@with_replaced_quantizers
def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path: str, is_tlm: bool = False, *args, **kwargs):

not required to change this

Signed-off-by: Onkar Chougule <[email protected]>
@ochougul merged commit f59a988 into quic:main on Dec 11, 2024
4 checks passed
To export both DLM/TLM, add below flags to `from_pretrained`:

```Python
tlm_name = "meta-llama/Llama-3.1-405B"
```

Can you please add a smaller model, for example the one we are using for testing?
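As a hedged sketch of what the reviewer is asking for, the snippet could use the small model already exercised in the unit-test logs above; the is_tlm flag follows the from_pretrained signature shown later in this PR, and the import path and exact flags are assumptions rather than the final documented API:

```Python
from QEfficient import QEFFAutoModelForCausalLM

# TinyLlama is the small model used in the test logs above; is_tlm marks it as the verifier (TLM).
tlm_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tlm = QEFFAutoModelForCausalLM.from_pretrained(tlm_name, is_tlm=True)
```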

import numpy as np
import torch

from QEfficient.utils import get_num_layers_from_config, get_padding_shape_from_config, padding_check_and_fix


class InputHandler:
def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size):
def __init__(
self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size, num_logits_to_keep: Optional[int]

Please set the default value to None. Otherwise, an error will occur when creating an instance of the InputHandler class without providing num_logits_to_keep.
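A minimal sketch of the suggested fix, giving the new parameter a None default so existing call sites keep working; the other parameters are reproduced from the excerpt above:

```Python
from typing import Optional

class InputHandler:
    def __init__(
        self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size,
        num_logits_to_keep: Optional[int] = None,  # default None keeps old call sites working
    ):
        ...
```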

if self.num_speculative_tokens:
compile_hash.update(to_hashable({"num_speculative_tokens": self.num_speculative_tokens}))

if self.is_dlm:
compile_hash.update(to_hashable({"is_dlm": self.is_dlm}))


We can update it in #176 after merging this PR

Labels: enhancement (New feature or request), in-review (Review process is ongoing), vllm